from typing import Dict, List, Optional, Tuple, cast

from fastapi import WebSocket
from langchain.agents import AgentExecutor
from langchain.callbacks import OpenAICallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool
from paperqa import Answer, Docs

from ..callback import RoutedWebsocketToolHandler, get_llm_name
from ..chains import search_chain
from ..docs import reduce_tokens
from ..models import AgentType, QueryRequest
from .tools import (
    DownstreamReferences,
    GatherEvidenceTool,
    GenerateAnswerTool,
    PaperSearchTool,
    SimilarPapersTool,
    UpstreamCitations,
    status,
)


def _make_tools(
    docs: Docs,
    answer: Answer,
    websocket: WebSocket,
    query: QueryRequest,
    use_experimental_tools: bool = False,
):
    """Instantiate the agent's tools, all wired to the same shared state.

    Every tool is constructed with the same ``docs``, ``answer``, ``websocket``
    and ``query`` objects, plus one ``token_counts`` dict created here and
    shared by all of them.

    Args:
        docs: Document collection handed to every tool.
        answer: The in-progress answer handed to every tool.
        websocket: Connection handed to every tool.
        query: The originating request handed to every tool.
        use_experimental_tools: When true, also register the
            ``UpstreamCitations`` and ``DownstreamReferences`` tools.

    Returns:
        A list of tool instances in this order: PaperSearchTool,
        GatherEvidenceTool, GenerateAnswerTool, SimilarPapersTool, then
        (optionally) UpstreamCitations and DownstreamReferences.
    """
    shared_counts: Dict[str, list[int]] = {}
    common = {
        "docs": docs,
        "answer": answer,
        "websocket": websocket,
        "query": query,
        "token_counts": shared_counts,
    }
    # Registration order matters: callers scan this list in order.
    tool_classes = [
        PaperSearchTool,
        GatherEvidenceTool,
        GenerateAnswerTool,
        SimilarPapersTool,
    ]
    if use_experimental_tools:
        tool_classes += [UpstreamCitations, DownstreamReferences]
    tools: List[BaseTool] = [tool_cls(**common) for tool_cls in tool_classes]
    return tools


async def run_agent(
    docs: Docs,
    query: QueryRequest,
    websocket: WebSocket,
    llm=None,
    agent_type: str = "OpenAIFunctionsAgent",
    search_type: str = "google",
    use_experimental_tools: bool = False,
) -> Tuple[Answer, Dict[str, List[int]]]:
    """Run the paper-QA agent for ``query`` and stream progress over ``websocket``.

    Args:
        docs: Document collection passed to every tool and to ``reduce_tokens``
            (via ``docs.llm`` / ``docs.summary_llm``).
        query: Incoming request; ``query.query`` is the question and
            ``query.summary_llm`` selects the answer's concurrency limit
            (25 for "claude" models, otherwise 5).
        websocket: Connection used to send the prompts, tool events (via the
            routed callback) and the agent trace.
        llm: Agent LLM. Replaced by ``ChatOpenAI(model="gpt-4", temperature=0)``
            when ``None`` or when its name does not start with ``"gpt-4"``.
        agent_type: ``"fake"`` runs a fixed tool sequence without an agent
            LLM; any other value is resolved via ``AgentType.get_agent``.
        search_type: Assigned to the ``paper_search`` tool's ``search_type``.
        use_experimental_tools: Also register the citation-graph tools.

    Returns:
        The populated ``Answer`` and the token-count dict produced by
        ``reduce_tokens``. In the non-fake path it also gets an ``"agent"``
        entry of ``[prompt_tokens, completion_tokens]``.
    """
    question = query.query

    # For now, we use the OpenAI functions agent by default (e.g. if llm is
    # not specified) and set temperature at 0 for this task. Note that any
    # non-GPT-4 llm passed in is also replaced.
    # TODO: Enable switching between OpenAI, Anthropic etc in .env configuration
    if llm is None or not get_llm_name(llm).startswith("gpt-4"):
        llm = ChatOpenAI(temperature=0.0, model="gpt-4", client=None)

    answer = Answer(
        question=question,
        dockey_filter=set(),
        max_concurrent=25 if "claude" in query.summary_llm else 5,
    )
    tools = _make_tools(
        docs,
        answer,
        websocket,
        query,
        use_experimental_tools,
    )

    # Locate the paper search tool (if registered) and configure its backend.
    search_tool: Optional[PaperSearchTool] = cast(
        Optional[PaperSearchTool],
        next((tool for tool in tools if tool.name == "paper_search"), None),
    )
    if search_tool is not None:
        search_tool.search_type = search_type

    if agent_type == "fake":
        # Seed with keyword searches. This is guarded exactly like the agent
        # path below; without the guard a missing paper_search tool crashed
        # with AttributeError on None.
        if search_tool is not None:
            for search in await search_chain(answer.question, 3):
                await search_tool.arun(search)
        # Call each tool once, in registration order, stopping after gen_answer.
        for tool in tools:
            await tool.arun(answer.question)
            if tool.name == "gen_answer":
                break
        # _make_tools hands the same token_counts dict to every tool, so reading
        # it from the first tool presumably covers all of them.
        tokens_dict, _ = reduce_tokens(
            tools[0].token_counts, docs.llm, docs.summary_llm
        )

    else:
        # Seed every time; the agent may not always choose the same search tool.
        if search_tool is not None:
            await search_tool.arun(answer.question)

        agent_instance = AgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=AgentType.get_agent(agent_type).from_llm_and_tools(llm, tools),
            return_intermediate_steps=True,
            handle_parsing_errors=True,
        )
        await websocket.send_json({"c": "prompts", "prompts": docs.prompts.dict()})

        tool_callback = RoutedWebsocketToolHandler(websocket)
        cost_callback = OpenAICallbackHandler()

        call_response = await agent_instance.acall(
            f"Answer question: {question}. Search for papers, gather evidence, and answer. "
            "If you do not have enough evidence, you can search for more papers (preferred) or "
            "gather more evidence with a different phrase. You may rephrase or break-up the question in those steps. "
            "Once you have five or more pieces of evidence from multiple sources, or you have tried many times, "
            "call gen_answer tool. "
            "The current status of evidence/papers/cost is "
            + await status(docs, answer, dict(), websocket),
            callbacks=[tool_callback, cost_callback],
        )
        # Serialize each (action, observation) pair so it can be sent as JSON.
        agent_trace = [
            {"agent_action": str(action), "response": str(observation)}
            for action, observation in call_response["intermediate_steps"]
        ]

        await websocket.send_json({"c": "agent-trace", "t": agent_trace})
        tokens_dict, _ = reduce_tokens(
            tools[0].token_counts, docs.llm, docs.summary_llm
        )

        # Add the agent's own LLM costs to the answer's costs.
        answer.cost += cost_callback.total_cost
        # NOTE(review): this stores a single int (total tokens), whereas
        # tokens_dict["agent"] below is [prompt, completion]. Confirm which
        # shape consumers of answer.token_counts expect.
        answer.token_counts["agent"] = cost_callback.total_tokens
        tokens_dict["agent"] = [
            cost_callback.prompt_tokens,
            cost_callback.completion_tokens,
        ]

    # NOTE: Deliberately do NOT use the agent's final message as the answer,
    # i.e. never add `answer.answer = call_response["output"]` (agent path
    # only). The author found it not as good as the gen_answer tool's answer.

    return answer, tokens_dict
